# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.tf.env import (
    tf,
)
from deepmd.tf.utils.type_embed import (
    TypeEmbedNet,
    embed_atom_type,
)


class TestTypeEbd(tf.test.TestCase):
    """Unit tests for the TF type-embedding helpers ``embed_atom_type`` and ``TypeEmbedNet``."""

    def test_embed_atom_type(self) -> None:
        """Check that ``embed_atom_type`` emits one embedding row per atom, chosen by type.

        ``natoms`` is ``[5, 5, 3, 0, 2]``. As the expected output shows, the
        trailing entries are per-type atom counts: 3 atoms of type 0, none of
        type 1, and 2 of type 2. The leading two entries presumably are the
        local/total atom counts (TODO confirm against ``embed_atom_type``).
        The result must therefore be row 0 of the table repeated three
        times, followed by row 2 repeated twice.
        """
        ntypes = 3
        natoms = tf.constant([5, 5, 3, 0, 2])
        type_embedding = tf.constant(
            [
                [1, 2, 3],
                [3, 2, 1],
                [7, 7, 7],
            ]
        )
        expected_out = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [7, 7, 7], [7, 7, 7]]
        atom_embed = embed_atom_type(ntypes, natoms, type_embedding)
        # Use the cached session as a context manager. The graph/session
        # default scopes it pushes are then popped deterministically, instead
        # of leaking through a bare ``__enter__()`` that is never exited.
        with self.cached_session() as sess:
            atom_embed_value = sess.run(atom_embed)
        np.testing.assert_almost_equal(atom_embed_value, expected_out, 10)

    def test_type_embed_net(self) -> None:
        """Regression-test ``TypeEmbedNet`` output against stored reference values.

        The net is built with ``seed=1`` and ``uniform_seed=True``, which
        presumably makes its initialization reproducible. The evaluated
        embedding table must have 2 rows (one per type) and 8 columns (the
        last entry of ``neuron``), and must match ``expected_out`` to 10
        decimals.
        """
        ten = TypeEmbedNet(
            ntypes=2, neuron=[2, 4, 8], seed=1, uniform_seed=True, use_tebd_bias=True
        )
        type_embedding = ten.build(2)
        # Scoped session: variables are initialized and the embedding is
        # evaluated while the session/graph are default, then the scopes are
        # released when the block exits.
        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            type_embedding_value = sess.run(type_embedding)

        expected_out = [
            1.429967002262267917e00,
            -9.138175897677495163e-01,
            -3.799606588218059633e-01,
            -2.143157692726757046e-01,
            2.341138114260268743e00,
            -1.568346043255314015e00,
            8.917082000854256174e-01,
            -1.500356675378008209e00,
            8.955885646123034061e-01,
            -5.835326470989941061e-01,
            -1.465708662924672057e00,
            -4.052047884085572260e-01,
            1.367825594590430072e00,
            -2.736204307656463497e-01,
            -4.044263041521370394e-01,
            -9.438057524881729998e-01,
        ]
        expected_out = np.reshape(expected_out, [2, 8])

        # 2 types
        self.assertEqual(type_embedding_value.shape[0], 2)
        # size of embedded vec 8
        self.assertEqual(type_embedding_value.shape[1], 8)
        # check value
        np.testing.assert_almost_equal(type_embedding_value, expected_out, 10)
